{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "CTC.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "1A6gJZcOVq-h",
        "colab_type": "code",
        "outputId": "f7031d2b-a78b-425b-dacf-a96d1362f1fc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "!pip install -qU git+https://github.com/harvardnlp/pytorch-struct@fixalign\n",
        "!pip install -qU git+https://github.com/harvardnlp/genbmm\n",
        "!pip install -q matplotlib"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r070BEdwVzHs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch_struct\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zxRC_exrbTR4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Character Vocab\n",
        "vocab = [\"a\", \"b\", \"c\", \"d\", \"e\", \"_\"]\n",
        "v_dict = { a:i for i, a in enumerate(vocab)}\n",
        "L = len(vocab)\n",
        "\n",
        "# Char sequence\n",
        "letters = \"a _ b _ c _ d _ e\"\n",
        "t = len(letters.split())\n",
        "\n",
        "# Frame sequence\n",
        "frames = [\"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\", \"_\", \"_\", \"d\", \"e\"]\n",
        "\n",
        "# Constants\n",
        "T, B = len(frames), 5\n",
        "D1, MATCH, D2 = 0, 1, 2\n",
        "\n",
        "def show(m):\n",
        "    plt.yticks(torch.arange(len(letters.split())), letters.split())\n",
        "    plt.xticks(torch.arange(T), [str(frames[x.item()]) for x in torch.arange(T)])\n",
        "    plt.imshow(m.cpu().detach())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qXQ1s1HvV2Nq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Gold alignment. \n",
        "gold = torch.zeros(B, t).long()\n",
        "for i, l in enumerate(letters.split()):\n",
        "    gold[:, i] = v_dict[l]\n",
        "gold = gold[:, None, :].expand(B, T, t)\n",
        "\n",
        "# Inputs (boost true frames a bit)\n",
        "logits = torch.rand(B, T, L)\n",
        "for i in range(T):\n",
        "    logits[:, i, v_dict[frames[i]]] += 1 "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "il4pW6M9YOKP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Construct the alignment problem from CTC\n",
        "\n",
        "# Log-Potentials\n",
        "log_potentials = torch.zeros(B, T, t, 3)\n",
        "\n",
        "# Match gold to logits. \n",
        "match = torch.gather(logits, 2, gold)\n",
        "\n",
        "# CTC Rules:\n",
        "\n",
        "# Never allowed to fully skip regular characters (little t)\n",
        "log_potentials[:, :, ::2,  D2] = -1e5\n",
        "\n",
        "# Free to skip _ characters (little t)\n",
        "log_potentials[:, :, 1::2, D2] = 0\n",
        "\n",
        "# First match with character is the logit. \n",
        "log_potentials[:, :, :, MATCH] = match\n",
        "\n",
        "# Additional match with character is the logit.\n",
        "log_potentials[:, :, :, D1] = match\n",
        "\n",
        "# (This might be slightly off)\n",
        "\n",
        "log_potentials = log_potentials.transpose(1, 2).cuda()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Hz5I8cnLfpHE",
        "colab_type": "code",
        "outputId": "560683d9-263c-4c82-acb5-2df1a0e1d4f0",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 263
        }
      },
      "source": [
        "# Show input scores\n",
        "show(match.transpose(1,2).exp()[0])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAANU0lEQVR4nO3dX2zdd33G8efxcZzYSZvGaTsYlIZA\nlZQwLVJc0dBWRYSpqqCDSSBQV3XTmHxBJaZIvUH0gptISJPQdjEqXAmtAi4qgtoBHagoVWHpRBP3\nj0NDIGVttoitfxKHpIk3Oz7+7MIuROlxfZL8Ps4n9fslVXV0Tt/51nGenv58co4jQgCAunou9gEA\nAG+NoQaA4hhqACiOoQaA4hhqACiOoQaA4nozoq2VK6N3cDAjLUn6k8HX0tqSdGymldo/cnBVal+t\n3PNrup3cn85rL1uW15YUU1Opfa9YntqPydzzx6r+1H7P5OnUfnugL609OTGu05On3Om2lKHuHRzU\nu/9ue0ZakrTnrvvT2pL0vZOXp/a/+We3pvZnVq9M7fccPZHanzk6ntbuecfVaW1Jmj70X6n91vrr\nUvszL+aef+rGD6b2+194NbV/Yssfp7XHdv3jvLdx6QMAimOoAaA4hhoAimOoAaA4hhoAimOoAaA4\nhhoAimOoAaC4roba9iO2n7a93/Zw9qEAAH/Q7Z9M/JuIGLfdL2mv7e9FxNHMgwEAZnV76eOLtsck\n/VzSNZLe9OdYbQ/bHrU9OnPyVJNnBIAlbcGhtv0RSR+TtDUi/lTSs5JWnH2/iBiJiKGIGOpZlfta\nEwCwlHTziHq1pGMRMWF7o6Qbk88EADhDN0P9Y0m9tg9I+qpmL38AABbJgt9MjIhJSbcvwlkAAB3w\nPGoAKI6hBoDiGGoAKI6hBoDiGGoAKI6hBoDiUt6FfMWRaV33z3kvBfLxr92W1pak9tFjqf07xp5L\n7e/cnvv56f/fqdT+zG//L63t10+mtWd/Aqfmo9VK7Q8+PpDa33NoJrW/7oGrUvuXP3U4rd06Nf/v\nKx5RA0BxDDUAFMdQA0BxDDUAFMdQA0BxDDUAFMdQA0BxDDUAFMdQA0BxDDUAFMdQA0BxjQ217WHb\no7ZHp9oTTWUBYMlrbKgjYiQihiJiqK+V+8IuALCUcOkDAIpjqAGgOIYaAIpjqAGgOIYaAIpb8K24\nbK+VtKvDTdsiIu/9tgAAkroY6rkx3rwIZwEAdMClDwAojqEGgOIYagAobsFr1OcjJic1c/DFjLQk\nqWf9tWltSdK7rkzNP/yFD6b2b/qHp1L7T7x8XWrfIx9Ka1/2i9fS2pI0vfm9qf3JNSm/ZX9v5qZf\npfbf9Yncr52Z3nZu/9jv8uLt+c/OI2oAKI6hBoDiGGoAKI6hBoDiGGoAKI6hBoDiGGoAKI6hBoDi\nFhxq2+tsP78YhwEAvBmPqAGguG6Hutf2d2wfsL3TNm8zDgCLpNuh3iDp6xFxvaQTkr5w9h1sD9se\ntT16OiabPCMALGndDvXhiHhy7uNvS7r57DtExEhEDEXE0DIvb+yAALDUdTvUscCPAQBJuh3q99je\nOvfxnZJ2J50HAHCWbof615LusX1A0hpJ9+cdCQBwpm7e3PaQpI35RwEAdMLzqAGgOIYaAIpjqAGg\nOIYaAIpjqAGgOIYaAIpb8Ol558PLl6vnfesz0rP98eNpbUnSS4dT838x9j+p/Z3bb0vtX/7ieGq/\n/cJTefG1g3ltSb2/eSm137dpQ2p/zZNrUvt7Dp1O7a97wKn9njVX5MWnWvP/vHk/KwCgCQw1ABTH\nUANAcQw1ABTHUANAcQw1ABTHUANAcQw1ABTHUANAcQw1ABTHUANAcY0Nte1h26O2R6faE01lAWDJ\na2yoI2IkIoYiYqivNdBUFgCWPC59AEBxDDUAFMdQA0BxDDUAFLfgO7zYXitpV4ebtkXE0eaPBAA4\n04JDPTfGmxfhLACADrj0AQDFMdQAUBxDDQDFMdQAUNyC30w8HzPLW5p47+qMtCRpYLqd1pakePVI\nav+HQ9ek9nVLbv70O/J+bSXpxI1b09qP7Pj7tLYkvTC9KrX/V4/dkNr/wGdfT+2//5X9qf2e/hWp\n/VM3b0hrz+xeNu9tPKIGgOIYagAojqEGgOIYagAojqEGgOIYagAojqEGgOIYagAojqEGgOIYagAo\nruuhtn237X22x2x/K/NQAIA/6Oq1PmxvknSfpA9HxBHbgx3uMyxpWJKW91/R6CEBYCnr9hH1RyV9\nNyKOSFJEjJ99h4gYiYihiBha1reyyTMCwJLGNWoAKK7boX5c0mfm3uhWnS59AABydHWNOiL2294h\n6ae225KelfTXmQcDAMzq+o0DIuJBSQ8mngUA0AHXqAGgOIYaAIpjqAGgOIYaAIpjqAGgOIYaAIrr\n+ul556Jnsq2Bl45npCVJPnEyrS1JbrVS+58YPZza37n9A6n9ZS/n/dpK0hX/9mxa+/P/+udpbUlq\njx9L7W/cdCK1v/qhidT+nkObUvvrHnBqf+XYb9PaPROn578t7WcFADSCoQaA4hhqACiOoQaA4hhq\nACiOoQaA4hhqACiOoQaA4hhqACiOoQaA4hhqACiusaG2PWx71PboVDv39QIAYClpbKgjYiQihiJi\nqK810FQWAJY8Ln0AQHEMNQAUx1ADQHEMNQAUt+A7vNheK2lXh5u2RcTR5o8EADjTgkM9N8abF+Es\nAIAOuPQBAMUx1ABQHEMNAMUx1ABQ3ILfTDwv09PSK0dS0pI0M3U6rS1JPSv7U/uPfvKG1P7yQ/tS\n+167Jre/rC+tPfGh96W1JanveO7X5tSKVmr/d7e3U/vvviVnct7wyX96LLX/w8/fmtaOY/N/bnhE\nDQDFMdQAUBxDDQDFMdQAUBxDDQDFMdQAUBxDDQDFMdQAUNw5D7Xtr9i+N+MwAIA34xE1ABTX1VDb\n/rLtg7Z3S9qQfCYAwBm6eYeXLZI+p9k3D+iV9Iykpzvcb1jSsCSt6FnV7CkBYAnr5hH1LZIejoiJ\niDgh6fud7hQRIxExFBFDfT0rGj0kACxlXKMGgOK6GeqfSfqU7X7bl0m6I/lMAIAzdPPmts/YfkjS\nmKRXJe1NPxUA4Pe6ehXviNghaUfyWQAAHXCNGgCKY6gBoDiGGgCKY6gBoDiGGgCKY6gBoLiunp53\n7tVe6Y+uTElLUs/48bS2JLWPHkvt3/HvuU9F37n9ttR+/4vjqf14+ZW09sBT/5HWlqT2eO7XTt+m\n3NdEu+JHrdT+wUPTqf1/uWdban/54ZfT2p6a/3PDI2oAKI6hBoDiGGoAKI6hBoDiGGoAKI6hBoDi\nGGoAKI6hBoDiGGoAKI6hBoDiGGoAKK6xobY9bHvU9uhUe6KpLAAseY0NdUSMRMRQRAz1tQaaygLA\nkselDwAojqEGgOIYagAojqEGgOIWfIcX22sl7epw07aIONr8kQAAZ1pwqOfGePMinAUA0AGXPgCg\nOIYaAIpjqAGgOIYaAIpb8JuJ5+X0tPTfr6SkJen1WzektSVp4NHcJ7M8+uH1qf1+vZDanzl5KrU/\nefsNae3W5ExaW5J6nzie2n994+rUvv+2ndp//8rTqX2NHUzNv3bnlrT29A/65r2NR9QAUBxDDQDF\nMdQAUBxDDQDFMdQAUBxDDQDFMdQAUBxDDQDFMdQAUFxXQ237Ltt7bD9n+xu2W9kHAwDMWnCobV8v\n6bOSboqIzZLakv4y+2AAgFndvNbHNklbJO21LUn9kl49+062hyUNS9KKnpUNHhEAlrZuhtqSHoyI\nL73VnSJiRNKIJK3uvSoaOBsAQN1do94l6dO2r5Yk24O2r809FgDgDQsOdUT8UtJ9kh6zvU/STyS9\nM/tgAIBZXb0edUQ8JOmh5LMAADrgedQAUBxDDQDFMdQAUBxDDQDFMdQAUBxDDQDFMdQAUJwjmv/T\n3rZfk/Sf5/CPXCnpSOMHyW/Tp0+fflPtayPiqk43pAz1ubI9GhFDl1qbPn369BejzaUPACiOoQaA\n4qoM9cgl2qZPnz799HaJa9QAgPlVeUQNAJgHQ71E2V5n+/lLtQ80xfZXbN97sc/xVhhqACjuog61\n7UdsP217/9yb49JfXL22v2P7gO2dtgcupb7tu23vsz1m+1tNtum/vdn+su2DtndL2pDQv8v2HtvP\n2f6G7dYFBSPiov0laXDu7/2Snpe0lv6ife7XSQpJN839+JuS7r2E+pskHZR05Zm/FvQXp38p/yVp\ni6RfSBqQdLmk3zT8tXm9pB9IWjb3469LuvtCmhf70scXbY9J+rmkayRdR39RHY6IJ+c+/rakmy+h\n/kclfTcijkhSRIw32Kb/9naLpIcjYiIiTkj6fsP9bZr9j8Fe28/N/Xj9hQS7es/EDLY/IuljkrZG\nxITtJyStoL+ozn5uZtPP1czuAxVZ0oMR8aWmghfzEfVqScfmRm6jpBvpL7r32N469/GdknZfQv3H\nJX3G9lpJsj3YYJv+29vPJH3Kdr/tyyTd0XB/l6RP275amv3c2772QoIXc6h/rNlvNh2Q9FXNXj6g\nv7h+LemeuX+HNZLuv1T6EbFf0g5JP527/PS1ptr0394i4hlJD0kak/QjSXsb7v9S0n2SHrO9T9JP\nJL3zQpr8yUQAKO6iXaPGW5v7X9ZdHW7aFhFHF/s8wBuyvzb52n8zHlEDQHEX++l5AIAFMNQAUBxD\nDQDFMdQAUBxDDQDFMdQAUNz/AyNIEYe39J50AAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xqDiA1mqfu6Q",
        "colab_type": "code",
        "outputId": "365d7416-b8ec-4dc8-8276-9d9eab0ffbf4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 263
        }
      },
      "source": [
        "# Find best alignment\n",
        "dist = torch_struct.AlignmentCRF(log_potentials)\n",
        "show(dist.argmax[0])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAJvUlEQVR4nO3dQailZR3H8d+/JmgEkxyTDLIhEBMX\nBc4iK0GcNi0Cg6KwsGhxF4WubCG6cBO0amnkIghzIRmKLYpiosRAmtEcUwelTbQTp6CF0CKeFveo\n1/HcuceZ973nf+58PjB4r/fMn2fOzHw9vue571NjjADQ1/vWvQAAzk+oAZoTaoDmhBqgOaEGaE6o\nAZo7NMfQuqpGjs4xedtNz843ez9s+PKBmYwxatm/rzn2UdexGjk1+di3LP+lbI4NXz4wk91C7dIH\nQHNCDdCcUAM0J9QAzQk1QHNCDdCcUAM0J9QAza0U6qp6oqqeraqXqmpr7kUB8LZVv4X8u2OMf1XV\n4SQnq+pXY4yzcy4MgG2rhvruqvrK4uOPJ7kuyTtCvXilvf1q+9qplgfAnpc+qurWJF9McvMY49NJ\n/prkg+c+bozx0Bjj2BjjWD4y+ToBLlmrXKO+Ism/xxhvVNWnknx25jUBsMMqof5tkkNVdSbJj5I8\nM++SANhpz2vUY4z/JvnSPqwFgCXsowZoTqgBmhNqgOaEGqA5oQZoTqgBmpvnFPKq6YceJHM/O445\nh43kFHKADSXUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzex7Ftaqq2kqyNdU8\nALa5KdM6uCkTsISbMgFsKKEGaE6oAZoTaoDmhBqguT2351XVkSQnlnzp+Bjj7PRLAmAn2/PWwfY8\nYAnb8wA2lFADNCfUAM1Ndq8P3oO5ryHPfA18+VW06bjEDu/kFTVAc0IN0JxQAzQn1ADNCTVAc0IN\n0JxQAzQn1ADN7RnqqjpaVS/ux2IAeDevqAGaWzXUh6rqkao6U1WPVdVls64KgLesGurrkzw4xrgh\nyX+SfO/cB1TVVlWdqqpTUy4Q4FK358EBVXU0yVNjjGsXn9+W5O4xxu3n+TkODlgnN2WCjXSxBwec\n+1dfiAH2yaqhvraqbl58fEeSp2daDwDnWDXUryT5flWdSfLhJD+Zb0kA7ORw24PINWrYSA63BdhQ\nQg3QnFADNCfUAM0JNUBzQg3Q3KF1L4AZzL59bu7dlzbowU5eUQM0J9QAzQk1QHNCDdCcUAM0J9QA\nzQk1QHNCDdCcUAM0J9QAzQk1QHOT3eujqraSbE01D4BtzkzkArgpE8zBmYkAG0qoAZoTaoDmhBqg\nuT13fVTVkSQnlnzp+Bjj7PRLAmAnuz64AHZ9wBzs+gDYUEIN0JxQAzQn1ADNTXavDy4lM7/ZN+N7\nlcvfquFNnp6evKIGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZpbOdRVdWdVvVBV\np6vq4TkXBcDbVrrXR1XdmOT+JJ8bY7xeVVcuecxWkq2J1wdwyVvphJequivJR8cY96001AkvXAw3\nZVobT896OeEFYEOtGuo/JPna4qDbLLv0AcA8Vj7ctqq+neQHSf6X5K9jjO+c57EufXDhXPpYG0/P\neu126cMp5PQj1Gvj6Vkv16gBNpRQAzQn1ADNCTVAc0IN0JxQAzS30r0+YF/NuEfM9rM9zLyxdu7t\nkQf199craoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaG6ye31U1VaSranm\nAbDNmYnA29yUaa2cmQiwoYQaoDmhBmhOqAGa23PXR1UdSXJiyZeOjzHOTr8kAHay6wN4m10fa2XX\nB8CGEmqA5oQaoDmhBmhusnt9AAfApr/ZN/c2hjW9W+kVNUBzQg3QnFADNCfUAM0JNUBzQg3QnFAD\nNCfUAM2951BX1QNVdc8ciwHg3byiBmhupVBX1X1V9WpVPZ3k+pnXBMAOq5zwclOSbyT5zOLxzyV5\ndsnjtpJsTb1AgEvdKjdluiXJ42OMN5Kkqp5c9qAxxkNJHlo8xgkvABNxjRqguVVC/VSS26vqcFVd\nnuTLM68JgB32vPQxxniuqh5NcjrJa0lOzr4qAN7iFHLg4NjwgwOcQg6woYQaoDmhBmhOqAGaE2qA\n5oQaoLlVvoUcYDPMvjN45v15u/CKGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZo\nTqgBmpvsXh9VtZVka6p5AGxzZiJwgGz2TZmcmQiwoYQaoDmhBmhOqAGa23PXR1UdSXJiyZeOjzHO\nTr8kAHay6wM4QOz6AGANhBqgOaEGaE6oAZqb7F4fAOs375t9s75XeWz3L3lFDdCcUAM0J9QAzQk1\nQHNCDdCcUAM0J9QAzQk1QHNCDdDcSqGuqm9V1V+q6vmq+mlVvX/uhQGwbc9QV9UNSb6e5PNjjM8k\n+V+Sb869MAC2rXKvj+NJbkpysqqS5HCS1859UFVtJdmadHUA7H3CS1XdleRjY4x7Vx7qhBfgIJr5\npkzj1IWf8HIiyVer6uokqaorq+oTU64PgN3tGeoxxstJ7k/yu6p6Icnvk1wz98IA2OZwW4BVNb70\nAcAaCTVAc0IN0JxQAzQn1ADNCTVAc0IN0Nwq9/q4EK8n+cd7ePxVi58zhzlnm2+++ZfS/KW7nCea\nnez6Hd+zfMPLe1VVp8YYxzZttvnmm2/+fsx26QOgOaEGaK5LqB/a0Nnmm2+++bPPbnGNGoDddXlF\nDcAuhPoSVVVHq+rFTZ0PU6mqB6rqnnWv43yEGqC5tYa6qp6oqmer6qXF4bjm769DVfVIVZ2pqseq\n6rJNml9Vd1bVC1V1uqoennK2+QdbVd1XVa9W1dNJrp9h/req6i9V9XxV/bSq3n9RA8cYa/uR5MrF\nPw8neTHJEfP37bk/mu3zKj6/+PxnSe7ZoPk3Jnk1yVU7fy/M35/5m/wjyU1J/pbksiQfSvL3if9s\n3pDk10k+sPj8wSR3XszMdV/6uLuqTid5JsnHk1xn/r765xjjz4uPf5HkCxs0/7YkvxxjvJ4kY4x/\nTTjb/IPtliSPjzHeGGP8J8mTE88/nu3/GJysqucXn3/yYgbOda+PPVXVrUm+mOTmMcYbVfXHJB80\nf1+duzdz6r2ac8+HjirJz8cY9041cJ2vqK9I8u9F5D6V5LPm77trq+rmxcd3JHl6g+b/IcnXqupI\nklTVlRPONv9geyrJ7VV1uKouT/LlieefSPLVqro62X7uq2rXGy6tYp2h/m2232w6k+RH2b58YP7+\neiXJ9xe/hg8n+cmmzB9jvJTkh0n+tLj89OOpZpt/sI0xnkvyaJLTSX6T5OTE819Ocn+S31XVC0l+\nn+Sai5npOxMBmlvbNWrOb/G/rCeWfOn4GOPsfq8H3jT3n01/9t/NK2qA5ta9PQ+APQg1QHNCDdCc\nUAM0J9QAzQk1QHP/BwKWOQXFkW6QAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wdMZ89-ehpxq",
        "colab_type": "code",
        "outputId": "277c5d37-e013-4f66-b350-33447562d448",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 281
        }
      },
      "source": [
        "# Find marginals (see uncertainty from randomness)\n",
        "show(dist.marginals[0])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAALEklEQVR4nO3dT6im5XkG8Ot2RpOZaGJmVEjaqARS\nlSw6oAttYgmZtFBKWgsJCWkqpYtDaSBQcBN0kY2QRekyIWeRIkkWkpQEu2hIOiERA+L/MZpJpJQW\noVDraDXFdlBzd3GOZjp+ej6Z9znzjPP7wWG+833vXNxn5nDNO8/5nvet7g4A8zrvTA8AwBtT1ACT\nU9QAk1PUAJNT1ACTU9QAk9s7IrQuqc6VI5K3HHpwXHaS/M/Y+PzL4PwXB+cDY3R3rXq+RryPuq6r\nzgOLx77qv1Z+Kct5bGx8PjU4/z8G5780OB/OVa9X1JY+ACanqAEmp6gBJqeoASanqAEmp6gBJqeo\nASanqAEmt1ZRV9V3q+rBqnq8qjZGDwXAr627hfwvuvuZqtqX5P6q+vvuPj5yMAC2rFvUn6+qP9l+\n/L4kH0jy/4p6+0x762z78qXGA2DHpY+q+kiSjyW5obt/O8nDSd5+6nHdvdnd13X3dbl08TkBzlnr\nrFG/K8mz3f1CVV2d5PrBMwFwknWK+ntJ9lbVsSRfSnLv2JEAONmOa9TdfSLJH+zCLACs4H3UAJNT\n1ACTU9QAk1PUAJNT1ACTU9QAkxtzF/I6r5O3LZ77ivPPPzgsO0kOvOP3hubfdP1dQ/Pv/afnhub/\n/KWXh+afGJoO83IXcoCzlKIGmJyiBpicogaYnKIGmJyiBpicogaYnKIGmJyiBpicogaYnKIGmNyO\nt+JaV1VtJNlYKg+ALYsVdXdvJtlMXrkoEwBLsPQBMDlFDTA5RQ0wOUUNMDlFDTC5Hd/1UVUHkxxZ\n8dLh7j6+/EgAnGzHot4u40O7MAsAK1j6AJicogaYnKIGmFx1L7/bu6rGbiHfMzQ95104Nn/fgbH5\n7/3rsfl/9rWx+Q8dG5f9oxPjspPkl2Pj8/LgfM6s7q5VzzujBpicogaYnKIGmJyiBpicogaYnKIG\nmJyiBpicogaY3I5FXVVXVtVjuzEMAK/ljBpgcusW9d6q+mZVHauqb1fV/qFTAfCqdYv6qiRf7u5r\nkjyf5K9OPaCqNqrqgap6YMkBAc516xb1k939k+3H30jy4VMP6O7N7r6uu69bbDoA1i7qU6+GN/bq\neAC8at2ivryqbth+/Jkk9wyaB4BTrFvUv0jyuao6luTdSb4ybiQATrbOzW3/NcnV40cBYBXvowaY\nnKIGmJyiBpicogaYnKIGmJyiBphcdS+/ybCqzvKdi3vGpu+pofl7+4qh+RdffM3Q/P17nhuW/dx/\nPzwsO0me/98Xhua/1L8ams+Z1d0ry8EZNcDkFDXA5BQ1wOQUNcDkFDXA5BQ1wOQUNcDkFDXA5BQ1\nwOQUNcDkFDXA5Ha8Fde6qmojycZSeQBsWayou3szyWbyVrgoE8A8LH0ATE5RA0xOUQNMTlEDTG7H\nHyZW1cEkR1a8dLi7jy8/EgAn27Got8v40C7MAsAKlj4AJqeoASanqAEmp6gBJlfdy+/2toX8DNsz\nOP7g2PzLfndc9h+/c1x2kvzh3WPzf/7k2Py/OTE2/6mx8Tnbi6e7a9XzzqgBJqeoASanqAEmp6gB\nJqeoASanqAEmp6gBJqeoASanqAEmp6gBJrd2UVfVzVX1aFUdraqvjxwKgF/b8cYBSVJVH0xyW5Lf\n6e6nq+rAimM2kmwsPB/AOW+tok7y0STf6u6nk6S7nzn1gO7eTLKZuCgTwJKsUQNMbt2i/mGST27f\n6Darlj4AGGOtpY/ufryqbk/y46p6OcnDSf585GAAbFl3jTrdfUeSOwbOAsAK1qgBJqeoASanqAEm\np6gBJqeoASanqAEmV93L7/a2hZzTs2dYctVFw7KT5ILz3z80/zcuGbvX7Pd/8/6h+c8e/eXQ/DtP\n/Gpo/mjdXaued0YNMDlFDTA5RQ0wOUUNMDlFDTA5RQ0wOUUNMDlFDTA5RQ0wOUUNMDlFDTC5tW/F\ntZOq2kiysVQeAFsWK+ru3kyymbgoE8CSLH0ATE5RA0xOUQNMTlEDTG7HHyZW1cEkR1a8dLi7jy8/\nEgAn27Got8v40C7MAsAKlj4AJqeoASanqAEmp6gBJlfdy+/2toWcaQ0+NTlv/9j8tx0Ym18vjs2/\n4Max+b/1jrH59/3d2PzurlXPO6MGmJyiBpicogaYnKIGmJyiBpicogaYnKIGmJyiBpjcmy7qqvpi\nVd0yYhgAXssZNcDk1irqqrq1qp6oqnuSXDV4JgBOss4dXq5N8uls3Txgb5KHkjy44riNJBtLDwhw\nrtuxqJPcmOQ73f1CklTVXasO6u7NJJvbx7goE8BCrFEDTG6dor47yU1Vta+qLkry8cEzAXCSdW5u\n+1BV3ZnkaJKnktw/fCoAXrXOGnW6+/Yktw+eBYAVrFEDTE5RA0xOUQNMTlEDTE5RA0xOUQNMrrqX\n3+1tCznnrrHnPnvOu3Bo/sUXvjQ0/+CLB4fmX7j/1qH5Dx3/y6H53V2rnndGDTA5RQ0wOUUNMDlF\nDTA5RQ0wOUUNMDlFDTA5RQ0wOUUNMDlFDTA5RQ0wubVuxbWOqtpIsrFUHgBbFivq7t5Mspm4KBPA\nkix9AExOUQNMTlEDTE5RA0xuxx8mVtXBJEdWvHS4u48vPxIAJ9uxqLfL+NAuzALACpY+ACanqAEm\np6gBJqeoASZX3cvv9raFHFjp/LHxF1w6Nv+994zL/vc/Sk78tGvVa86oASanqAEmp6gBJqeoASan\nqAEmp6gBJqeoASanqAEmp6gBJrdWUVfVZ6vqvqp6pKq+WlV7Rg8GwJYdi7qqrknyqSQf6u5DSV5O\n8qejBwNgy443DkhyOMm1Se6vqiTZl+SpUw+qqo0kG4tOB8BaRV1J7ujuL7zRQd29mWQzcVEmgCWt\ns0Z9JMknquqyJKmqA1V1xdixAHjFjkXd3T9LcluS71fVo0l+kOQ9owcDYMs6Sx/p7juT3Dl4FgBW\n8D5qgMkpaoDJKWqAySlqgMkpaoDJKWqAySlqgMlV9/K7vavqP5P825v4LZckeXrxQcZny5cvX/5S\n2Vd096WrXhhS1G9WVT3Q3dedbdny5cuXvxvZlj4AJqeoASY3S1FvnqXZ8uXLlz88e4o1agBe3yxn\n1AC8DkV9jqqqK6vqsbM1H5ZSVV+sqlvO9BxvRFEDTO6MFnVVfbeqHqyqx7dvjit/d+2tqm9W1bGq\n+nZV7T+b8qvq5qp6tKqOVtXXl8yW/9ZWVbdW1RNVdU+Sqwbkf7aq7quqR6rqq1W157QCu/uMfSQ5\nsP3rviSPJTkof9f+7K9M0kk+tP3515LcchblfzDJE0kuOfnvQv7u5J/NH0muTfLTJPuTvDPJPy/8\nvXlNkn9Icv72519OcvPpZJ7ppY/PV9XRJPcmeV+SD8jfVU9290+2H38jyYfPovyPJvlWdz+dJN39\nzILZ8t/abkzyne5+obufT3LXwvmHs/WPwf1V9cj25+8/ncC17pk4QlV9JMnHktzQ3S9U1Y+SvF3+\nrjr1vZlLv1dzdD7MqJLc0d1fWCrwTJ5RvyvJs9sld3WS6+Xvusur6obtx59Jcs9ZlP/DJJ+sqoNJ\nUlUHFsyW/9Z2d5KbqmpfVV2U5OML5x9J8omquizZ+rOvqitOJ/BMFvX3svXDpmNJvpSt5QP5u+sX\nST63/TW8O8lXzpb87n48ye1Jfry9/PS3S2XLf2vr7oeS3JnkaJJ/THL/wvk/S3Jbku9X1aNJfpDk\nPaeTaWciwOTO2Bo1b2z7v6xHVrx0uLuP7/Y88IrR35u+91/LGTXA5M702/MA2IGiBpicogaYnKIG\nmJyiBpicogaY3P8Bqipi36WFMjoAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sWLCBtu1hr5S",
        "colab_type": "code",
        "outputId": "9bd8cf58-f7e6-438b-b8ab-d33f381c2b15",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        }
      },
      "source": [
        "# Print the log sum exp of the data (for training)\n",
        "print(dist.partition)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([28.5271, 28.8374, 28.3810, 27.0438, 26.3049], device='cuda:0',\n",
            "       grad_fn=<SqueezeBackward1>)\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}